
[Common] Persistent Grouped MXFP8 quantization kernel#2738

Open
Oleg-Goncharov wants to merge 58 commits into NVIDIA:main from Oleg-Goncharov:pr_persistent_grouped_mxfp8_kernel

Conversation


@Oleg-Goncharov Oleg-Goncharov commented Mar 5, 2026

Description

This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added persistent kernel
  • Added TunableConfig structure to tune performance

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Mar 5, 2026
@Oleg-Goncharov Oleg-Goncharov requested a review from ptrendx March 5, 2026 16:18

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces a persistent grouped MXFP8 quantization kernel with a static grid-stride scheduler. The kernel replaces the previous one-CTA-per-block dispatch with a persistent grid (sm_count × STATIC_PERSISTENT_BLOCKS_PER_SM blocks) that iterates over a virtual work grid, enabling better SM utilization for workloads with large numbers of small tensors (including zero-row tensors in the middle of a group). A TunableConfig struct centralises all performance knobs, ShapeRepresentation is promoted to a compile-time template parameter (eliminating runtime branching inside the hot loop), and a USE_FAST_MATH path is added that uses native BF16/FP16 FMA instructions for higher throughput on SM 10.0+ hardware.

Key changes:

  • Persistent scheduler: JobDescriptor / BlockDescriptor abstractions decouple logical block coordinates from physical CTA IDs; advance_to_next_job drives the grid-stride loop.
  • Empty-tensor safety: is_job_valid returns true for zero-row/zero-col tensor slots and the outer loop skips them, enabling CUDA-graph-safe group tensors with trailing or interleaved empty slots.
  • Fast-math path: New mul_cvt_4x PTX overloads (BF16/FP16 scale × BF16/FP16 data → FP8) added to ptx.cuh; FPx4 gains alignas(4*sizeof(T)) required for the b64 shared-memory loads used in these paths.
  • API: nvte_group_quantize_v2 exposes QuantizationConfig (including use_fast_math) to callers.
  • Refactoring: ShapeRepresentation consolidated into utils.cuh; two new compile-time dispatch macros (TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH, TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH) reduce kernel branching at the cost of larger binary size.
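The static schedule in the first bullet can be modeled on the host (a sketch, not the PR's code: `schedule_jobs` is a hypothetical stand-in, and the inner loop plays the role the summary attributes to `advance_to_next_job`):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model of the persistent grid-stride schedule: each of the
// grid_size "CTAs" walks the virtual work grid with a stride equal to the
// grid size, so every logical block is visited exactly once.
std::vector<int> schedule_jobs(std::size_t sm_count,
                               std::size_t blocks_per_sm,
                               std::size_t total_work_blocks) {
    const std::size_t grid_size = sm_count * blocks_per_sm;
    std::vector<int> visits(total_work_blocks, 0);
    for (std::size_t cta = 0; cta < grid_size; ++cta) {
        // advance_to_next_job: job_id += grid_size until past the work grid
        for (std::size_t job = cta; job < total_work_blocks; job += grid_size) {
            ++visits[job];  // in the kernel: decode_job + process the chunk
        }
    }
    return visits;
}
```

Because the stride equals the grid size, every logical block is visited exactly once whether the virtual work grid is larger or smaller than the launched grid.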

Issues found:

  • The kernel-internal USE_FAST_MATH && !NON_FP32_CAST_ONLY guard silently returns without writing any output, whereas it should fire a device error so that invalid configurations surface immediately.
  • is_job_valid accesses offsets_ptr[tensor_id + 1] for all non-SAME_BOTH_DIMS shapes; this relies on offsets_ptr having num_tensors + 1 entries, an invariant that is not documented or asserted anywhere near the call site.
  • The test change to compare outputs as a flat (1, elts_num) shape instead of the original (rows, cols) silently skips scale validation for all rows beyond the first 32.
  • STATIC_PERSISTENT_BLOCKS_PER_SM = 24 likely exceeds the hardware limit of ~16 concurrent blocks per SM (2048 threads / 128 threads-per-block), and no compile-time guard enforces that derived TunableConfig combinations remain valid.
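The occupancy arithmetic behind the last bullet is straightforward; a compile-time guard of the kind the review asks for could look like this (a sketch: the 2048-thread and 128-thread figures come from the review text, and `MAX_BLOCKS_PER_SM` is an assumed name):

```cpp
#include <cassert>

constexpr unsigned STATIC_PERSISTENT_BLOCKS_PER_SM = 24;  // value from the review
constexpr unsigned MAX_THREADS_PER_SM = 2048;             // typical SM thread limit
constexpr unsigned THREADS_PER_BLOCK = 128;               // block size from the review
constexpr unsigned MAX_BLOCKS_PER_SM = MAX_THREADS_PER_SM / THREADS_PER_BLOCK;  // 16

// A guard like this would catch the over-provisioned value at compile time
// instead of letting the extra blocks serialize:
// static_assert(STATIC_PERSISTENT_BLOCKS_PER_SM <= MAX_BLOCKS_PER_SM,
//               "persistent blocks per SM exceed the hardware concurrency limit");
constexpr bool config_is_valid = STATIC_PERSISTENT_BLOCKS_PER_SM <= MAX_BLOCKS_PER_SM;
```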

Confidence Score: 3/5

  • The persistent kernel logic and barrier management are sound, but a silent no-op path for USE_FAST_MATH + float input and a reduced test coverage for scale validation deserve attention before merging.
  • The core algorithmic changes are carefully constructed—barrier parity is self-consistent across persistent job iterations, empty-tensor handling is correct, and the double-buffered pipeline is properly sequenced. The fast-math PTX paths are architecture-specific and guarded by CUDA_ARCH checks. The main concerns are: (1) a kernel-level guard silently drops all output instead of erroring for unsupported USE_FAST_MATH + IType combinations; (2) the test suite's switch to flat (1×N) comparison silently stops validating scale rows beyond the first; (3) the offsets_ptr[tensor_id+1] assumption in is_job_valid is undocumented; and (4) STATIC_PERSISTENT_BLOCKS_PER_SM=24 over-provisions the SM for typical hardware thread limits.
  • transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh (kernel guard and is_job_valid invariant), tests/cpp/operator/test_cast_mxfp8_grouped.cu (scale comparison coverage).

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Core kernel file with the new persistent grouped MXFP8 quantization kernel. Adds TunableConfig, JobDescriptor/BlockDescriptor structs, decode_job/decode_block/is_job_valid helpers, a grid-stride persistent scheduler, and USE_FAST_MATH code path with BF16/FP16 native FMA. Scale-units mismatch in rowwise_scale_is_within_bounds check (scale index vs. data columns) is a pre-existing issue not fully resolved in this PR.
transformer_engine/common/util/ptx.cuh Adds alignas(4*sizeof(T)) to FPx4 struct and adds four new mul_cvt_4x overloads that accept BF16 or FP16 scale values using native mixed-precision FMA. The alignment fix is necessary for the b64 shared-memory loads used in the fast-math inline assembly paths.
transformer_engine/common/cast/dispatch/quantize.cuh Threads a QuantizationConfig pointer through group_quantize_fwd_helper and group_quantize_bwd_helper into the mxfp8::group_quantize call. Forward and backward paths updated consistently.
transformer_engine/common/utils.cuh ShapeRepresentation enum consolidated here from cast/core/common.cuh and hadamard_transform/graph_safe_group_hadamard_transform.cu, making it a shared type in the common header. Clean refactor.
transformer_engine/common/common.h Two new dispatch macros: TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH and TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH. These lift ScalingType and ShapeRepresentation to compile-time template arguments for the kernel. Both macros include a default error case. Well-formed.
tests/cpp/operator/test_cast_mxfp8_grouped.cu Adds use_fast_math parameter to the test suite, tests empty-tensor-in-middle cases, uses nvte_group_quantize_v2 for CAST_ONLY, and switches to flat (1×elts_num) comparison to avoid mismatches from graph-safe trailing-garbage regions. The reference simulation correctly models fast-math truncation.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_group_quantize_v2] --> B[group_quantize_fwd_helper\nwith QuantizationConfig]
    B --> C{Scaling Type}
    C -->|MXFP8_1D_SCALING| D[mxfp8::group_quantize]
    D --> E[Validate fast_math\nconstraints]
    E --> F{is_single_tensor?}
    F -->|No| G[update_tma_descriptors\n1 thread per tensor block\nSkips empty rows/cols tensors]
    F -->|Yes| H[Use static TMA maps]
    G --> I
    H --> I[Compute work_blocks_X/Y]
    I --> J{PERSISTENT mode?}
    J -->|Yes| K[grid = sm_count × 24 × 1]
    J -->|No| L[grid = work_blocks_X × work_blocks_Y]
    K --> M[group_quantize_mxfp8_kernel\nSHAPE_REP compile-time template\nUSE_FAST_MATH compile-time template]
    L --> M
    M --> N[Init barriers\nBUFFS_NUM=2]
    N --> O{job_finished?}
    O -->|No| P[decode_job → JobDescriptor\nis_job_valid check]
    P --> Q{job_has_work?}
    Q -->|No: empty tensor| R[advance_to_next_job\ncontinue]
    R --> O
    Q -->|Yes| S[decode_block → BlockDescriptor\nfence_acquire_tensormap if new tensor]
    S --> T[Prime pipeline:\nprefetch stage 0]
    T --> U[Process STAGES=4 slices\ncolwise + rowwise per stage\nDouble-buffered TMA]
    U --> V{IS_DBIAS?\nis_single_tensor?}
    V -->|Yes| W[Write partial dbias\nto workspace]
    V -->|No| X
    W --> X[advance_to_next_job]
    X --> O
    O -->|Done| Y[atomicMaxFloat amax_ptr\ndestroy_barriers]
    Y --> Z{IS_DBIAS?}
    Z -->|Yes| AA[grouped_reduce_dbias]
    Z -->|No| AB[Done]
    AA --> AB

Last reviewed commit: "Merge branch 'main' ..."

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 924ff91 to 325181b Compare March 6, 2026 10:39
}

const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec;
OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
Contributor

Output stride assumes uniform cols across all tensors

The output write offset is computed as:

OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;

where cols is last_logical_dim — a single value shared across all tensors in the group. This is correct for SAME_BOTH_DIMS and VARYING_FIRST_DIM (where all tensors share the same last dimension), but the kernel receives shape_rep as a parameter and does not enforce that restriction.

For VARYING_LAST_DIM or VARYING_BOTH_DIMS where per-tensor cols differ, the fixed tensor_id * cols stride would compute wrong output offsets. Currently, tests skip dbias validation for these cases, but the kernel would produce incorrect results if actually called with varying-last-dim tensors.

Consider adding a device-side assertion to enforce the precondition:

Suggested change (insert the guard before the existing line):

  if (shape_rep != ShapeRepresentation::SAME_BOTH_DIMS &&
      shape_rep != ShapeRepresentation::VARYING_FIRST_DIM) {
    NVTE_DEVICE_ERROR("group_reduce_dbias_kernel requires uniform last dimensions across tensors");
  }
  OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;

Oleg-Goncharov and others added 15 commits March 10, 2026 11:58
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@ptrendx ptrendx force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 5815335 to aa484a3 Compare March 10, 2026 19:07
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
Contributor

Wrong units in rowwise_scale_is_within_bounds guard

scales_offset_X_rowwise is a scale index (one entry per 32-element column group), while cols is the number of data columns. Comparing them directly means the guard almost never fires.

Concretely, with cols = 96 and SCALE_DIM_X = 32:

  • scales_offset_X_rowwise for the four threads of the last (and only) X-block is {0, 1, 2, 3}
  • Valid scale positions covering real data: {0, 1, 2} (covering columns 0–31, 32–63, 64–95)
  • The current check 3 < 96 evaluates to true, so thread 3 still writes a spurious scale for the nonexistent columns 96–127

The correct comparison multiplies the scale index back to column units:

Suggested change
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise * SCALE_DIM_X < cols;

This correctly excludes scale index 3 because 3 * 32 = 96, which is not < 96.
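The before/after behaviour can be reproduced with a few lines of host code using the same numbers (cols = 96, SCALE_DIM_X = 32; `old_guard` and `fixed_guard` are illustrative names, not the kernel's):

```cpp
#include <cassert>
#include <cstddef>

constexpr std::size_t SCALE_DIM_X = 32;

// Old guard: compares a scale index directly against a data-column count.
bool old_guard(std::size_t scale_idx, std::size_t cols) {
    return scale_idx < cols;
}

// Fixed guard: convert the scale index back to column units first.
bool fixed_guard(std::size_t scale_idx, std::size_t cols) {
    return scale_idx * SCALE_DIM_X < cols;
}
```

Here old_guard(3, 96) is true, so the spurious write goes through, while fixed_guard(3, 96) is false because 3 * 32 = 96 is not less than 96.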

Comment on lines +174 to +191
__device__ __forceinline__ JobDescriptor decode_job(
const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors,
const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X,
const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr,
const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr) {
JobDescriptor job{};
job.block_id = ctaid_Y * work_blocks_X + ctaid_X;
job.block_global_offset = is_single_tensor
? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X)
: (job.block_id * ELTS_PER_CHUNK);
job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y,
first_logical_dim, last_logical_dim, offsets_ptr);
job.rows =
get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
return job;
}
Member

This should be a constructor of the JobDescriptor struct (you can make the constructor __device__ too).
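A shape sketch of that suggestion, with the decode logic folded into a constructor (host-compilable here; in the kernel the constructor would be __device__ __forceinline__, the tensor_id/rows/cols lookups elided below would move in as well, and ELTS_PER_CHUNK is assumed to be 128 * 128):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t ELTS_PER_CHUNK = 128 * 128;  // assumed chunk size

struct JobDescriptor {
    std::size_t block_id;
    std::size_t block_global_offset;
    // Members are declared (and thus initialized) in order, so block_id is
    // valid when block_global_offset's initializer reads it.
    JobDescriptor(int32_t ctaid_X, int32_t ctaid_Y, std::size_t work_blocks_X,
                  bool is_single_tensor, std::size_t last_logical_dim,
                  std::size_t chunk_dim_Y, std::size_t chunk_dim_X)
        : block_id(static_cast<std::size_t>(ctaid_Y) * work_blocks_X + ctaid_X),
          block_global_offset(is_single_tensor
              ? ctaid_Y * chunk_dim_Y * last_logical_dim + ctaid_X * chunk_dim_X
              : block_id * ELTS_PER_CHUNK) {}
};
```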

Comment on lines +218 to +232
__device__ __forceinline__ BlockDescriptor
decode_block(const JobDescriptor &job, const bool is_single_tensor,
const int64_t *const __restrict__ offsets_ptr) {
BlockDescriptor block{};
block.tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[job.tensor_id]);
const size_t CHUNK_DIM_X_ = CHUNK_DIM_X;
const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_);
block.block_id_in_current_tensor =
is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK);
block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor;
block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor;
block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y;
block.block_offset_X = block.block_id_X * CHUNK_DIM_X;
return block;
}
Member

Similarly this should be a constructor too.

const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size,
uint64_t *barrier, const bool leading_thread) {
if (leading_thread) {
ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);
Member

2 questions - why is this done before the TMA call and why is it done only by the leading_thread? In the other parts of the code (e.g. in ptx::copy_2d_to_shared) we do transfer, then arrive_expect on the leading thread and just arrive on all the other threads.

Collaborator Author

ptx::mbarrier_arrive_expect_tx is also called by a single thread in ptx::copy_2d_to_shared. I initialized the barriers using a single thread, which is sufficient for it to work. But we can also keep the previous approach, where all threads in the block participate explicitly. And since the async copy and expect_tx are in the same phase, it’s also valid to issue expect_tx first.

Comment on lines +714 to +716
if (launch_block_id >= total_work_blocks) {
return;
}
Member

Is this possible?

Collaborator Author

For example, for small input tensors where total_work_blocks is less than SMs * K, with K = STATIC_PERSISTENT_BLOCKS_PER_SM
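That answer is easy to quantify: a persistent launch sizes the grid from the hardware rather than from the work, so the guard handles the overshoot (a sketch; the 148-SM figure in the usage note is illustrative, not from the PR):

```cpp
#include <cassert>
#include <cstddef>

// Number of launched CTAs that find no work and return immediately via the
// launch_block_id >= total_work_blocks early exit.
std::size_t idle_ctas(std::size_t sm_count, std::size_t blocks_per_sm,
                      std::size_t total_work_blocks) {
    const std::size_t grid = sm_count * blocks_per_sm;
    return grid > total_work_blocks ? grid - total_work_blocks : 0;
}
```

For example, 148 SMs with 24 blocks each gives a 3552-CTA grid; an input with only 100 work blocks leaves 3452 CTAs taking the early return.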

last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr,
first_dims_ptr, last_dims_ptr);
allow_next_job_prefetch =
is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr);
Member

If we are prevalidating the next job here, then why do we need the earlier check that the job we are about to do is valid, and why drain its prefetch if it is not?

Collaborator Author

We prefetch the first stage of the next CTA at the end of processing the current CTA. This check is only to avoid copying data for null blocks. The main termination check, i.e., when to stop processing the current chunk and exit the loop is at line 770.

is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr);
if (!current_job_is_valid) {
if (has_prefetched_current_job) {
// A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting.
Member

Why do you need to drain it?

Collaborator Author

We destroy the barriers after exiting the loop. But this invalidation can only be done once the mbarrier objects are guaranteed to have completed their current phase (drained). Otherwise, the TMA engine may finish the copy and attempt to signal completion on an already-invalidated mbarrier.

@ptrendx ptrendx marked this pull request as draft March 12, 2026 16:26
Oleg-Goncharov and others added 2 commits March 13, 2026 17:08
@Oleg-Goncharov Oleg-Goncharov marked this pull request as ready for review March 18, 2026 14:07
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 7c41a6a to 6874935 Compare March 18, 2026 14:24
Oleg-Goncharov and others added 3 commits March 18, 2026 15:25
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 7bdc696 to 5068556 Compare March 18, 2026 14:34
pre-commit-ci bot and others added 6 commits March 18, 2026 14:34
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
// Empty tensor in the middle of the group must not terminate the persistent work loop.
{VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128},
Collaborator

Can you also add a test case for varying only the first dim that contains a group of size zero?

Oleg-Goncharov and others added 10 commits March 18, 2026 18:29
@Oleg-Goncharov (Collaborator Author)

/te-ci
